Newer
Older
BlackoutClient / Assets / Best HTTP / Source / SecureProtocol / crypto / tls / DtlsRecordLayer.cs
#if !BESTHTTP_DISABLE_ALTERNATE_SSL && (!UNITY_WEBGL || UNITY_EDITOR)
#pragma warning disable
using System;
using System.IO;

using BestHTTP.SecureProtocol.Org.BouncyCastle.Utilities.Date;

namespace BestHTTP.SecureProtocol.Org.BouncyCastle.Crypto.Tls
{
    /// <summary>
    /// DTLS record layer on top of a <see cref="DatagramTransport"/>.
    /// Outgoing data is protected with the current write epoch's cipher and framed into records with a
    /// 13-byte header (content type, version, epoch, 48-bit sequence number, fragment length).
    /// Incoming datagrams are split into records. Each record is filtered (content type, epoch, replay
    /// window, version), decrypted with the matching epoch's cipher and then handled according to its
    /// content type.
    /// </summary>
    internal class DtlsRecordLayer
        :   DatagramTransport
    {
        /// <summary>Record header size: type(1) + version(2) + epoch(2) + sequence number(6) + length(2).</summary>
        private const int RECORD_HEADER_LENGTH = 13;
        /// <summary>Initial plaintext limit applied by the constructor (2^14 bytes).</summary>
        private const int MAX_FRAGMENT_LENGTH = 1 << 14;
        /// <summary>TCP maximum segment lifetime in milliseconds (2 minutes).</summary>
        private const long TCP_MSL = 1000L * 60 * 2;
        /// <summary>
        /// Milliseconds after a successful handshake during which the previous epoch is kept for
        /// handshake retransmissions (2 * TCP_MSL).
        /// </summary>
        private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2;

        private readonly DatagramTransport mTransport;
        private readonly TlsContext mContext;
        private readonly TlsPeer mPeer;

        // Bytes of a datagram that followed its first record. ReceiveRecord serves them on later calls,
        // one record at a time, before reading the transport again.
        private readonly ByteQueue mRecordQueue = new ByteQueue();

        private volatile bool mClosed = false;
        // Set when the connection ended due to an error; suppresses the close_notify in CloseTransport.
        private volatile bool mFailed = false;
        // mReadVersion: null until the first accepted record, after which other versions are discarded.
        // mWriteVersion: nothing is sent while this is null (see SendRecord).
        private volatile ProtocolVersion mReadVersion = null, mWriteVersion = null;
        private volatile bool mInHandshake;
        private volatile int mPlaintextLimit;
        // mCurrentEpoch: the established epoch. mPendingEpoch: the next epoch created by InitPendingEpoch.
        private DtlsEpoch mCurrentEpoch, mPendingEpoch;
        // Epochs currently used to decrypt incoming records and to protect outgoing records.
        private DtlsEpoch mReadEpoch, mWriteEpoch;

        // After the handshake, the previous epoch and its retransmit handler are kept until
        // mRetransmitExpiry (Unix ms), or until non-handshake data arrives in the new epoch.
        private DtlsHandshakeRetransmit mRetransmit = null;
        private DtlsEpoch mRetransmitEpoch = null;
        private long mRetransmitExpiry = 0;

        /// <summary>
        /// Creates a record layer over <paramref name="transport"/>. It starts in handshake mode, with
        /// epoch 0 and a null cipher for both reading and writing.
        /// </summary>
        /// <param name="transport">Underlying datagram transport.</param>
        /// <param name="context">TLS context used to build the initial null cipher.</param>
        /// <param name="peer">Peer notified of raised and received alerts.</param>
        /// <param name="contentType">Currently unused.</param>
        internal DtlsRecordLayer(DatagramTransport transport, TlsContext context, TlsPeer peer, byte contentType)
        {
            this.mTransport = transport;
            this.mContext = context;
            this.mPeer = peer;

            this.mInHandshake = true;

            this.mCurrentEpoch = new DtlsEpoch(0, new TlsNullCipher(context));
            this.mPendingEpoch = null;
            this.mReadEpoch = mCurrentEpoch;
            this.mWriteEpoch = mCurrentEpoch;

            SetPlaintextLimit(MAX_FRAGMENT_LENGTH);
        }

        /// <summary>True once the underlying transport has been closed by this layer.</summary>
        internal bool IsClosed
        {
            get { return mClosed; }
        }

        /// <summary>
        /// Sets the maximum plaintext fragment length. It applies to both sent and received records.
        /// </summary>
        internal virtual void SetPlaintextLimit(int plaintextLimit)
        {
            this.mPlaintextLimit = plaintextLimit;
        }

        /// <summary>Epoch number currently used to decrypt incoming records.</summary>
        internal virtual int ReadEpoch
        {
            get { return mReadEpoch.Epoch; }
        }

        /// <summary>
        /// Version that incoming records must carry. When null, the first accepted record's version
        /// is adopted.
        /// </summary>
        internal virtual ProtocolVersion ReadVersion
        {
            get { return mReadVersion; }
            set { this.mReadVersion = value; }
        }

        /// <summary>
        /// Sets the version written into outgoing record headers. No records are sent until it is set.
        /// </summary>
        internal virtual void SetWriteVersion(ProtocolVersion writeVersion)
        {
            this.mWriteVersion = writeVersion;
        }

        /// <summary>
        /// Creates the pending epoch (write epoch + 1) that will use <paramref name="pendingCipher"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">A pending epoch already exists.</exception>
        internal virtual void InitPendingEpoch(TlsCipher pendingCipher)
        {
            if (mPendingEpoch != null)
                throw new InvalidOperationException();

            /*
             * TODO "In order to ensure that any given sequence/epoch pair is unique, implementations
             * MUST NOT allow the same epoch value to be reused within two times the TCP maximum segment
             * lifetime."
             */

            // TODO Check for overflow
            this.mPendingEpoch = new DtlsEpoch(mWriteEpoch.Epoch + 1, pendingCipher);
        }

        /// <summary>
        /// Leaves handshake mode and makes the pending epoch current. If <paramref name="retransmit"/>
        /// is non-null, the old epoch stays usable for handshake records for RETRANSMIT_TIMEOUT ms.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The read or write side has not yet moved off the current epoch, meaning change_cipher_spec
        /// has not been processed in both directions.
        /// </exception>
        internal virtual void HandshakeSuccessful(DtlsHandshakeRetransmit retransmit)
        {
            if (mReadEpoch == mCurrentEpoch || mWriteEpoch == mCurrentEpoch)
            {
                // TODO
                throw new InvalidOperationException();
            }

            if (retransmit != null)
            {
                this.mRetransmit = retransmit;
                this.mRetransmitEpoch = mCurrentEpoch;
                this.mRetransmitExpiry = DateTimeUtilities.CurrentUnixMs() + RETRANSMIT_TIMEOUT;
            }

            this.mInHandshake = false;
            this.mCurrentEpoch = mPendingEpoch;
            this.mPendingEpoch = null;
        }

        /// <summary>
        /// Points the write epoch back at the retransmit epoch if one is kept, otherwise at the current
        /// epoch. NOTE(review): presumably called before re-sending a handshake flight; confirm at call sites.
        /// </summary>
        internal virtual void ResetWriteEpoch()
        {
            if (mRetransmitEpoch != null)
            {
                this.mWriteEpoch = mRetransmitEpoch;
            }
            else
            {
                this.mWriteEpoch = mCurrentEpoch;
            }
        }

        /// <summary>
        /// Largest plaintext that can be received. This is the smaller of the configured plaintext limit
        /// and what the read cipher allows within the transport's receive limit minus the record header.
        /// </summary>
        public virtual int GetReceiveLimit()
        {
            return System.Math.Min(this.mPlaintextLimit,
                mReadEpoch.Cipher.GetPlaintextLimit(mTransport.GetReceiveLimit() - RECORD_HEADER_LENGTH));
        }

        /// <summary>
        /// Largest plaintext that can be sent. This is the smaller of the configured plaintext limit and
        /// what the write cipher allows within the transport's send limit minus the record header.
        /// </summary>
        public virtual int GetSendLimit()
        {
            return System.Math.Min(this.mPlaintextLimit,
                mWriteEpoch.Cipher.GetPlaintextLimit(mTransport.GetSendLimit() - RECORD_HEADER_LENGTH));
        }

        /// <summary>
        /// Receives the next record whose plaintext should be delivered to the caller, and copies that
        /// plaintext into <paramref name="buf"/> at <paramref name="off"/>.
        /// </summary>
        /// <remarks>
        /// Records are silently skipped when they are truncated, have an unknown content type, carry an
        /// unexpected epoch, fail the replay window, are not DTLS, do not match the read version, or
        /// decrypt to more than the plaintext limit. Alert, change_cipher_spec and heartbeat records are
        /// consumed internally and never returned. Application data is dropped while the handshake is in
        /// progress, and handshake records are dropped (or passed to the retransmit handler) after it.
        /// </remarks>
        /// <returns>
        /// The number of plaintext bytes copied, or the negative value returned by the transport.
        /// </returns>
        /// <exception cref="TlsFatalAlert">A fatal alert was received from the peer.</exception>
        public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
        {
            byte[] record = null;

            for (;;)
            {
                int receiveLimit = System.Math.Min(len, GetReceiveLimit()) + RECORD_HEADER_LENGTH;
                if (record == null || record.Length < receiveLimit)
                {
                    record = new byte[receiveLimit];
                }

                // Stop accepting records for the previous epoch once the retransmit window has elapsed.
                if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry)
                {
                    mRetransmit = null;
                    mRetransmitEpoch = null;
                }

                // Exceptions from the transport (e.g. timeouts) or the cipher propagate unchanged.
                int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
                if (received < 0)
                {
                    return received;
                }
                if (received < RECORD_HEADER_LENGTH)
                {
                    continue;
                }
                // The declared fragment length must account exactly for the bytes received.
                int length = TlsUtilities.ReadUint16(record, 11);
                if (received != (length + RECORD_HEADER_LENGTH))
                {
                    continue;
                }

                byte type = TlsUtilities.ReadUint8(record, 0);

                // TODO Support user-specified custom protocols?
                switch (type)
                {
                case ContentType.alert:
                case ContentType.application_data:
                case ContentType.change_cipher_spec:
                case ContentType.handshake:
                case ContentType.heartbeat:
                    break;
                default:
                    // TODO Exception?
                    continue;
                }

                int epoch = TlsUtilities.ReadUint16(record, 3);

                // Accept the current read epoch. For handshake records only, also accept the previous
                // epoch while it is kept for retransmissions.
                DtlsEpoch recordEpoch = null;
                if (epoch == mReadEpoch.Epoch)
                {
                    recordEpoch = mReadEpoch;
                }
                else if (type == ContentType.handshake && mRetransmitEpoch != null
                    && epoch == mRetransmitEpoch.Epoch)
                {
                    recordEpoch = mRetransmitEpoch;
                }

                if (recordEpoch == null)
                {
                    continue;
                }

                long seq = TlsUtilities.ReadUint48(record, 5);
                if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
                {
                    continue;
                }

                ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
                if (!version.IsDtls)
                {
                    continue;
                }

                if (mReadVersion != null && !mReadVersion.Equals(version))
                {
                    continue;
                }

                byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext(
                    GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH,
                    received - RECORD_HEADER_LENGTH);

                // Only mark the sequence number as seen after the record decoded successfully.
                recordEpoch.ReplayWindow.ReportAuthenticated(seq);

                if (plaintext.Length > this.mPlaintextLimit)
                {
                    continue;
                }

                if (mReadVersion == null)
                {
                    mReadVersion = version;
                }

                switch (type)
                {
                case ContentType.alert:
                {
                    if (plaintext.Length == 2)
                    {
                        byte alertLevel = plaintext[0];
                        byte alertDescription = plaintext[1];

                        mPeer.NotifyAlertReceived(alertLevel, alertDescription);

                        if (alertLevel == AlertLevel.fatal)
                        {
                            Failed();
                            throw new TlsFatalAlert(alertDescription);
                        }

                        // TODO Can close_notify be a fatal alert?
                        if (alertDescription == AlertDescription.close_notify)
                        {
                            CloseTransport();
                        }
                    }

                    continue;
                }
                case ContentType.application_data:
                {
                    if (mInHandshake)
                    {
                        // TODO Consider buffering application data for new epoch that arrives
                        // out-of-order with the Finished message
                        continue;
                    }
                    break;
                }
                case ContentType.change_cipher_spec:
                {
                    // Implicitly receive change_cipher_spec and change to pending cipher state

                    for (int i = 0; i < plaintext.Length; ++i)
                    {
                        byte message = TlsUtilities.ReadUint8(plaintext, i);
                        if (message != ChangeCipherSpec.change_cipher_spec)
                        {
                            continue;
                        }

                        if (mPendingEpoch != null)
                        {
                            mReadEpoch = mPendingEpoch;
                        }
                    }

                    continue;
                }
                case ContentType.handshake:
                {
                    if (!mInHandshake)
                    {
                        if (mRetransmit != null)
                        {
                            mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length);
                        }

                        // TODO Consider support for HelloRequest
                        continue;
                    }
                    break;
                }
                case ContentType.heartbeat:
                {
                    // TODO[RFC 6520]
                    continue;
                }
                }

                /*
                 * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
                 * received our final flight.
                 */
                if (!mInHandshake && mRetransmit != null)
                {
                    this.mRetransmit = null;
                    this.mRetransmitEpoch = null;
                }

                Array.Copy(plaintext, 0, buf, off, plaintext.Length);
                return plaintext.Length;
            }
        }

        /// <summary>
        /// Sends <paramref name="len"/> bytes from <paramref name="buf"/> as one record.
        /// </summary>
        /// <remarks>
        /// During the handshake, or while the write epoch is the retransmit epoch, the data is sent as
        /// handshake content. Otherwise it is sent as application data. If the handshake message is
        /// Finished, a change_cipher_spec record is sent first and the write epoch switches to the next
        /// epoch (pending during the handshake, current when retransmitting).
        /// </remarks>
        /// <exception cref="IOException"/>
        /// <exception cref="InvalidOperationException">A Finished message is sent but no next epoch exists.</exception>
        public virtual void Send(byte[] buf, int off, int len)
        {
            byte contentType = ContentType.application_data;

            if (this.mInHandshake || this.mWriteEpoch == this.mRetransmitEpoch)
            {
                contentType = ContentType.handshake;

                // An empty fragment has no handshake type byte to read. Reading buf[off] here would
                // inspect data outside the caller's fragment. SendRecord rejects the empty fragment.
                if (len > 0 && TlsUtilities.ReadUint8(buf, off) == HandshakeType.finished)
                {
                    DtlsEpoch nextEpoch = null;
                    if (this.mInHandshake)
                    {
                        nextEpoch = mPendingEpoch;
                    }
                    else if (this.mWriteEpoch == this.mRetransmitEpoch)
                    {
                        nextEpoch = mCurrentEpoch;
                    }

                    if (nextEpoch == null)
                    {
                        // TODO
                        throw new InvalidOperationException();
                    }

                    // Implicitly send change_cipher_spec and change to pending cipher state

                    // TODO Send change_cipher_spec and finished records in single datagram?
                    byte[] data = new byte[]{ ChangeCipherSpec.change_cipher_spec };
                    SendRecord(ContentType.change_cipher_spec, data, 0, data.Length);

                    mWriteEpoch = nextEpoch;
                }
            }

            SendRecord(contentType, buf, off, len);
        }

        /// <summary>
        /// Closes the layer. While the handshake is in progress, a user_canceled warning is sent first.
        /// CloseTransport then sends close_notify unless the connection has failed.
        /// </summary>
        public virtual void Close()
        {
            if (!mClosed)
            {
                if (mInHandshake)
                {
                    Warn(AlertDescription.user_canceled, "User canceled handshake");
                }
                CloseTransport();
            }
        }

        /// <summary>
        /// Marks the connection as failed and closes the transport without sending close_notify.
        /// </summary>
        internal virtual void Failed()
        {
            if (!mClosed)
            {
                mFailed = true;

                CloseTransport();
            }
        }

        /// <summary>
        /// Tries to send a fatal alert with <paramref name="alertDescription"/>, then marks the connection
        /// as failed and closes the transport. Errors from sending the alert are ignored.
        /// </summary>
        internal virtual void Fail(byte alertDescription)
        {
            if (!mClosed)
            {
                try
                {
                    RaiseAlert(AlertLevel.fatal, alertDescription, null, null);
                }
                catch (Exception)
                {
                    // Best-effort: the connection is being torn down regardless.
                }

                mFailed = true;

                CloseTransport();
            }
        }

        /// <summary>Sends a warning-level alert with the given description.</summary>
        internal virtual void Warn(byte alertDescription, string message)
        {
            RaiseAlert(AlertLevel.warning, alertDescription, message, null);
        }

        /// <summary>
        /// Closes the underlying transport once. If the connection has not failed, close_notify is sent
        /// first. Errors while sending or closing are ignored.
        /// </summary>
        private void CloseTransport()
        {
            if (!mClosed)
            {
                /*
                 * RFC 5246 7.2.1. Unless some other fatal alert has been transmitted, each party is
                 * required to send a close_notify alert before closing the write side of the
                 * connection. The other party MUST respond with a close_notify alert of its own and
                 * close down the connection immediately, discarding any pending writes.
                 */

                try
                {
                    if (!mFailed)
                    {
                        Warn(AlertDescription.close_notify, null);
                    }
                    mTransport.Close();
                }
                catch (Exception)
                {
                    // Best-effort shutdown: the layer is marked closed either way.
                }

                mClosed = true;
            }
        }

        /// <summary>Notifies the peer that an alert was raised, then sends it as a 2-byte alert record.</summary>
        private void RaiseAlert(byte alertLevel, byte alertDescription, string message, Exception cause)
        {
            mPeer.NotifyAlertRaised(alertLevel, alertDescription, message, cause);

            byte[] error = new byte[] { alertLevel, alertDescription };

            SendRecord(ContentType.alert, error, 0, error.Length);
        }

        /// <summary>
        /// Copies the next record into <paramref name="buf"/>. Bytes left over from an earlier datagram
        /// are served first. Otherwise one datagram is read from the transport, and any bytes after its
        /// first record (as sized by that record's length field) are queued for later calls.
        /// </summary>
        /// <returns>
        /// The number of bytes copied. This may be a truncated record, which Receive rejects. A negative
        /// value is passed through from the transport.
        /// </returns>
        private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis)
        {
            if (mRecordQueue.Available > 0)
            {
                int length = 0;
                if (mRecordQueue.Available >= RECORD_HEADER_LENGTH)
                {
                    // Peek the fragment length (header offset 11) without consuming queued bytes.
                    byte[] lengthBytes = new byte[2];
                    mRecordQueue.Read(lengthBytes, 0, 2, 11);
                    length = TlsUtilities.ReadUint16(lengthBytes, 0);
                }

                int received = System.Math.Min(mRecordQueue.Available, RECORD_HEADER_LENGTH + length);
                mRecordQueue.RemoveData(buf, off, received, 0);
                return received;
            }

            {
                int received = mTransport.Receive(buf, off, len, waitMillis);
                if (received >= RECORD_HEADER_LENGTH)
                {
                    int fragmentLength = TlsUtilities.ReadUint16(buf, off + 11);
                    int recordLength = RECORD_HEADER_LENGTH + fragmentLength;
                    if (received > recordLength)
                    {
                        // Several records share this datagram; keep the rest for subsequent calls.
                        mRecordQueue.AddData(buf, off + recordLength, received - recordLength);
                        received = recordLength;
                    }
                }
                return received;
            }
        }

        /// <summary>
        /// Encrypts the fragment with the write epoch's cipher, prepends the record header and sends it
        /// as one datagram. Nothing is sent while the write version is unset.
        /// </summary>
        /// <exception cref="TlsFatalAlert">
        /// internal_error if the fragment exceeds the plaintext limit, or is empty for a non
        /// application_data content type.
        /// </exception>
        private void SendRecord(byte contentType, byte[] buf, int off, int len)
        {
            // Never send anything until a valid ClientHello has been received
            if (mWriteVersion == null)
                return;

            if (len > this.mPlaintextLimit)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            /*
             * RFC 5246 6.2.1 Implementations MUST NOT send zero-length fragments of Handshake, Alert,
             * or ChangeCipherSpec content types.
             */
            if (len < 1 && contentType != ContentType.application_data)
                throw new TlsFatalAlert(AlertDescription.internal_error);

            int recordEpoch = mWriteEpoch.Epoch;
            long recordSequenceNumber = mWriteEpoch.AllocateSequenceNumber();

            byte[] ciphertext = mWriteEpoch.Cipher.EncodePlaintext(
                GetMacSequenceNumber(recordEpoch, recordSequenceNumber), contentType, buf, off, len);

            // TODO Check the ciphertext length?

            byte[] record = new byte[ciphertext.Length + RECORD_HEADER_LENGTH];
            TlsUtilities.WriteUint8(contentType, record, 0);
            ProtocolVersion version = mWriteVersion;
            TlsUtilities.WriteVersion(version, record, 1);
            TlsUtilities.WriteUint16(recordEpoch, record, 3);
            TlsUtilities.WriteUint48(recordSequenceNumber, record, 5);
            TlsUtilities.WriteUint16(ciphertext.Length, record, 11);
            Array.Copy(ciphertext, 0, record, RECORD_HEADER_LENGTH, ciphertext.Length);

            mTransport.Send(record, 0, record.Length);
        }

        /// <summary>
        /// Combines the epoch (top 16 bits) and the 48-bit record sequence number into the 64-bit
        /// sequence value passed to the cipher.
        /// </summary>
        private static long GetMacSequenceNumber(int epoch, long sequence_number)
        {
            return ((epoch & 0xFFFFFFFFL) << 48) | sequence_number;
        }
    }
}
#pragma warning restore
#endif